#include <torch/extension.h>

#include "forward.h"
#include "backward.h"
#include "cuda_ops.h"

// Python entry point for this extension.
//
// The module name is taken from the TORCH_EXTENSION_NAME macro, which the
// PyTorch C++/CUDA extension build is expected to define. Three C++ functions
// are registered, in this order, under the same Python names as their C++
// symbols. Their declarations come from the headers included at the top of
// this file; which header declares which function is not visible here.
// NOTE(review): each `&fn` must name a single, non-overloaded function, or the
// address-of becomes ambiguous for pybind11.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  // Standalone tensor addition binding.
  module.def("cuda_add", &cuda_add, "CUDA add two tensors");

  // Forward / backward pair for the p2e log-probability op.
  module.def("p2e_logproba_forward", &p2e_logproba_forward,
             "p2e logproba forward (CUDA)");
  module.def("p2e_logproba_backward", &p2e_logproba_backward,
             "p2e logproba backward (CUDA)");
}